In [ ]:
import os
import numpy as np
import matplotlib
from matplotlib import pyplot as plt
from tqdm import tqdm

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
from torchvision.utils import save_image
import torchvision.transforms as transforms
from torchvision import datasets, models, transforms
#from warmup_scheduler import GradualWarmupScheduler 

from torch.utils.tensorboard import SummaryWriter
import shutil
#import cv2

plt.style.use('seaborn')
import seaborn as sns
sns.set_style("whitegrid", {'axes.grid' : False})

import data_loaders as dl


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device
Out[ ]:
device(type='cuda')

Assignment 5

  • a) Write Convolutional Variational Autoencoder for SVHN dataset
    • Model:
      • Use Conv. layers for encoder and TransposedConv. layers
      • You are allowed to use one FC-layer in each module for the bottleneck, but it's not necessary

Original dataset has weird labeling: images with 0 have label 10. But Pytorch has proper implementation of this dataset (0 is 0). So we have not to care about it

In [ ]:
mean = np.array([0.4914, 0.4822, 0.4465])
std = np.array([0.2470, 0.2435, 0.2616])

data_transforms = {
    'train':transforms.Compose([
                        transforms.ToTensor(),
                        transforms.Normalize(mean, std)
                    ]),
    'val': transforms.Compose([
                        transforms.ToTensor(),
                        transforms.Normalize(mean, std)
                    ])
}

train_loader, valid_loader = dl.load_train_data('svhn', 634, train_transf=data_transforms["train"], test_transf=data_transforms["val"], use_cutmix=False)
Using downloaded and verified file: ./data/train_32x32.mat
Using downloaded and verified file: ./data/test_32x32.mat
In [ ]:
def show_n_samples_from_batch(sample_imgs, sample_labels, n_samples=6):
    """ Plotting n_samples from a batch of images """
    fig, ax = plt.subplots(1,n_samples)
    fig.set_size_inches(3 * n_samples, 3)

    ids = np.random.randint(low=0, high=len(sample_imgs), size=n_samples)

    for i, n in enumerate(ids):
        img = sample_imgs[n]

        # because of normalization of dataset images have strange range, let's fix it with normalization to range [0, 1] for adequate visualization
        img = img.clone().detach()
        img += np.abs(img.min())
        img /= img.max()
        label = f"{str(sample_labels[n].numpy())} "
        ax[i].imshow(img.permute(1,2,0))
        ax[i].set_title(f"Label: {label}")
        ax[i].axis("off")
    plt.show()

train_batch_aug = next(iter(train_loader))
print(train_batch_aug[0].shape)
sample_imgs_aug, sample_labels_aug = train_batch_aug

#helpers.show_n_samples_from_batch(sample_imgs, sample_labels, 8)
show_n_samples_from_batch(sample_imgs_aug, sample_labels_aug, 10)
torch.Size([634, 3, 32, 32])
In [ ]:
train_loader.dataset.labels.max()
Out[ ]:
9
In [ ]:
if not os.path.exists("imgs/conv_vae"):
    os.makedirs("imgs/conv_vae")

if not os.path.exists("imgs/vae"):
    os.makedirs("imgs/vae")

if not os.path.exists("imgs/cvae"):
    os.makedirs("imgs/cvae")
In [ ]:
def save_model(model, optimizer, epoch, stats, model_name):
    """ Saving model checkpoint """
    
    if(not os.path.exists("models")):
        os.makedirs("models")
    if(not os.path.exists(f"models/{model_name}")):
        os.makedirs(f"models/{model_name}")
    savepath = f"models/{model_name}/checkpoint_epoch_{epoch}.pth"

    torch.save({
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'stats': stats
    }, savepath)
    return


def load_model(model, optimizer, savepath):
    """ Loading pretrained checkpoint """
    
    checkpoint = torch.load(savepath)
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    epoch = checkpoint["epoch"]
    stats = checkpoint["stats"]
    
    return model, optimizer, epoch, stats


def add_noise(img, mean=0, sigma=0.3):
    """ Adding AWGN to images"""
    noisy_img = img + torch.normal(mean * torch.ones(img.shape), sigma)
    return noisy_img.clamp(0,1)
In [ ]:
def train_epoch(model, train_loader, optimizer, criterion, epoch, device, isConditional=False, classNum=10):
    """ Training a model for one epoch """
    
    loss_list = []
    recons_loss = []
    vae_loss = []
    
    progress_bar = tqdm(enumerate(train_loader), total=len(train_loader))
    for i, (images, labels) in progress_bar:
        images = images.to(device)
        if(isConditional):
            labels = labels.to(device)
            labels = F.one_hot(labels, classNum).float()
        
        # Clear gradients w.r.t. parameters
        optimizer.zero_grad()
         
        # Forward pass
        if(not isConditional):
            recons, (z, mu, log_var) = model(images)
        else:
            recons, (z, mu, log_var) = model(images, labels)
         
        # Calculate Loss
        loss, (mse, kld) = criterion(recons, images, mu, log_var)
        loss_list.append(loss.item())
        recons_loss.append(mse.item())
        vae_loss.append(kld.item())
        
        # Getting gradients w.r.t. parameters
        loss.backward()
         
        # Updating parameters
        optimizer.step()
        
        progress_bar.set_description(f"Epoch {epoch+1} Iter {i+1}: loss {loss.item():.5f}. ")
        
    mean_loss = np.mean(loss_list)
    
    return mean_loss, loss_list


@torch.no_grad()
def eval_model(model, eval_loader, criterion, device, epoch=None, savefig=False, savepath="", isConditional=False, classNum=10):
    """ Evaluating the model for either validation or test """
    loss_list = []
    recons_loss = []
    kld_loss = []
    
    for i, (images, labels) in enumerate(eval_loader):
        images = images.to(device)
        if(isConditional):
            labels = labels.to(device)
            labels = F.one_hot(labels, classNum).float()

        
        # Forward pass 
        if(not isConditional):
            recons, (z, mu, log_var) = model(images)
        else:
            recons, (z, mu, log_var) = model(images, labels)
        loss, (mse, kld) = criterion(recons, images, mu, log_var)
        loss_list.append(loss.item())
        recons_loss.append(mse.item())
        kld_loss.append(kld.item())
        
        if(i==0 and savefig):
            save_image( recons[:64].cpu(), os.path.join(savepath, f"recons{epoch}.png") )
            
    # Total correct predictions and loss
    loss = np.mean(loss_list)
    recons_loss = np.mean(recons_loss)
    kld_loss = np.mean(kld_loss)
    return loss, recons_loss, kld_loss


def train_model(model, optimizer, scheduler, criterion, train_loader,
                valid_loader, num_epochs, savepath, save_frequency=5, isConditional=False, model_name="default"):
    """ Training a model for a given number of epochs"""
    
    train_loss = []
    val_loss =  []
    val_loss_recons =  []
    val_loss_kld =  []
    loss_iters = []
    
    for epoch in range(num_epochs):
           
        # validation epoch
        model.eval()  # important for dropout and batch norms
        log_epoch = (epoch % 5 == 0 or epoch == num_epochs - 1)
        loss, recons_loss, kld_loss = eval_model(
                model=model, eval_loader=valid_loader, criterion=criterion,
                device=device, epoch=epoch, savefig=log_epoch, savepath=savepath, isConditional=isConditional
            )
        val_loss.append(loss)
        val_loss_recons.append(recons_loss)
        val_loss_kld.append(kld_loss)
        
        # training epoch
        model.train()  # important for dropout and batch norms
        mean_loss, cur_loss_iters = train_epoch(
                model=model, train_loader=train_loader, optimizer=optimizer,
                criterion=criterion, epoch=epoch, device=device, isConditional=isConditional
            )
        
        # PLATEAU SCHEDULER
        scheduler.step(val_loss[-1])
        train_loss.append(mean_loss)
        loss_iters = loss_iters + cur_loss_iters
        
        if(epoch % save_frequency == 0):
            stats = {
                "train_loss": train_loss,
                "valid_loss": val_loss,
                "loss_iters": loss_iters
            }
            save_model(model=model, optimizer=optimizer, epoch=epoch, stats=stats, model_name=model_name)
        
        if(log_epoch):
            print(f"    Train loss: {round(mean_loss, 5)}")
            print(f"    Valid loss: {round(loss, 5)}")
            print(f"       Valid loss recons: {round(val_loss_recons[-1], 5)}")
            print(f"       Valid loss KL-D:   {round(val_loss_kld[-1], 5)}")
    
    print(f"Training completed")
    return train_loss, val_loss, loss_iters, val_loss_recons, val_loss_kld


def smooth(f, K=5):
    """ Smoothing a function using a low-pass filter (mean) of size K """
    kernel = np.ones(K) / K
    f = np.concatenate([f[:int(K//2)], f, f[int(-K//2):]])  # to account for boundaries
    smooth_f = np.convolve(f, kernel, mode="same")
    smooth_f = smooth_f[K//2: -K//2]  # removing boundary-fixes
    return smooth_f


def set_random_seed(random_seed=None):
    """
    Using random seed for numpy and torch
    """
    if(random_seed is None):
        random_seed = CONFIG["random_seed"]
    os.environ['PYTHONHASHSEED'] = str(random_seed)
    random.seed(random_seed)
    np.random.seed(random_seed)
    torch.manual_seed(random_seed)
    torch.cuda.manual_seed_all(random_seed)
    return


def count_model_params(model):
    """ Counting the number of learnable parameters in a nn.Module """
    num_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    return num_params
In [ ]:
def normalize_img(img):
    # normalize image according to the mean and std of the dataset
    img = img.permute(1,2,0)
    img = img * torch.tensor(std).view(1,1,3) + torch.tensor(mean).view(1,1,3)
    img = img.clip(0, 1)
    return img

In [ ]:
class Flatten(nn.Module):
    def forward(self, input):
        return input.view(input.size(0), -1)

class UnFlatten(nn.Module):
    def forward(self, input, size=256):
        return input.view(input.size(0), size, 1, 1)

class VAE(nn.Module):
    def __init__(self, image_channels=3, h_dim=256, latent_size=10):
        super(VAE, self).__init__()
        self.encoder = nn.Sequential(
            nn.Conv2d(image_channels, 32, kernel_size=3, stride=2),
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=3, stride=2),
            nn.ReLU(),
            nn.Conv2d(64, 128, kernel_size=3, stride=2),
            nn.ReLU(),
            nn.Conv2d(128, 256, kernel_size=3, stride=2),
            nn.ReLU(),
            Flatten()
        )
        
        self.fc1 = nn.Linear(h_dim, latent_size)
        self.fc2 = nn.Linear(h_dim, latent_size)
        self.fc3 = nn.Linear(latent_size, h_dim)
        
        self.decoder = nn.Sequential(
            UnFlatten(),
            nn.ConvTranspose2d(h_dim, 128, kernel_size=3, stride=1),
            nn.ReLU(),
            nn.ConvTranspose2d(128, 64, kernel_size=3, stride=2),
            nn.ReLU(),
            nn.ConvTranspose2d(64, 32, kernel_size=3, stride=2),
            nn.ReLU(),
            nn.ConvTranspose2d(32, image_channels, kernel_size=4, stride=2),
            nn.Sigmoid(),
        )
        
    def reparameterize(self, mu, logvar):
        std = logvar.mul(0.5).exp_().to(device)
        mu = mu.to(device)
        # return torch.normal(mu, std)
        esp = torch.randn(*mu.size()).to(device)
        z = mu + std * esp
        return z
    
    def bottleneck(self, h):
        mu, logvar = self.fc1(h), self.fc2(h)
        z = self.reparameterize(mu, logvar)
        return z, mu, logvar

    def encode(self, x):
        
        h = self.encoder(x)
        z, mu, logvar = self.bottleneck(h)
        return z, mu, logvar

    def decode(self, z):
        z = self.fc3(z)
        z = self.decoder(z)
        return z

    def forward(self, x):
        z, mu, logvar = self.encode(x)
        x_hat = self.decode(z)
        return x_hat, (z, mu, logvar)

def loss_fn(recon_x, x, mu, logvar):
    #BCE = F.binary_cross_entropy(recon_x, x, size_average=False)
    recons_loss = F.mse_loss(recon_x, x, reduction='sum')
    KLD = -0.5 * torch.mean(1 + logvar - mu.pow(2) - logvar.exp())

    return recons_loss + KLD, (recons_loss, KLD)
In [ ]:
vae = VAE(latent_size=12).to(device)
optimizer = torch.optim.Adam(vae.parameters(), lr=1e-3)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=5, verbose=True)
criterion = loss_fn

train_los_vae, val_los_vae, los_iters_vae, val_los_recons_vae, val_los_kld_vae = train_model(
    model=vae, optimizer=optimizer, scheduler=scheduler, criterion=criterion, train_loader=train_loader, valid_loader=valid_loader, num_epochs=45, savepath="imgs/vae", save_frequency=5,
    model_name="vae"   
)
Epoch 1 Iter 115: loss 929579.68750. : 100%|██████████| 115/115 [00:04<00:00, 28.64it/s] 
    Train loss: 1093294.66033
    Valid loss: 2155940.9096
       Valid loss recons: 2155940.9096
       Valid loss KL-D:   0.00124
Epoch 2 Iter 115: loss 874473.18750. : 100%|██████████| 115/115 [00:05<00:00, 21.94it/s]
Epoch 3 Iter 115: loss 908051.18750. : 100%|██████████| 115/115 [00:05<00:00, 20.10it/s]
Epoch 4 Iter 115: loss 869255.87500. : 100%|██████████| 115/115 [00:05<00:00, 21.22it/s]
Epoch 5 Iter 115: loss 871465.68750. : 100%|██████████| 115/115 [00:07<00:00, 15.58it/s]
Epoch 6 Iter 115: loss 830616.56250. : 100%|██████████| 115/115 [00:07<00:00, 16.40it/s]
    Train loss: 833628.88424
    Valid loss: 993481.32506
       Valid loss recons: 993415.38504
       Valid loss KL-D:   65.94064
Epoch 7 Iter 115: loss 861443.75000. : 100%|██████████| 115/115 [00:05<00:00, 19.66it/s]
Epoch 8 Iter 115: loss 888703.25000. : 100%|██████████| 115/115 [00:07<00:00, 14.96it/s]
Epoch 9 Iter 115: loss 817095.25000. : 100%|██████████| 115/115 [00:06<00:00, 16.84it/s]
Epoch 10 Iter 115: loss 824744.37500. : 100%|██████████| 115/115 [00:07<00:00, 15.57it/s]
Epoch 11 Iter 115: loss 865907.12500. : 100%|██████████| 115/115 [00:09<00:00, 11.99it/s]
    Train loss: 821817.35272
    Valid loss: 978109.61272
       Valid loss recons: 978065.47545
       Valid loss KL-D:   44.13595
Epoch 12 Iter 115: loss 758271.37500. : 100%|██████████| 115/115 [00:07<00:00, 15.69it/s]
Epoch 13 Iter 115: loss 864348.93750. : 100%|██████████| 115/115 [00:07<00:00, 15.49it/s]
Epoch 14 Iter 115: loss 785289.62500. : 100%|██████████| 115/115 [00:08<00:00, 13.63it/s]
Epoch 15 Iter 115: loss 822915.62500. : 100%|██████████| 115/115 [00:11<00:00, 10.13it/s]
Epoch 16 Iter 115: loss 823016.12500. : 100%|██████████| 115/115 [00:08<00:00, 14.33it/s]
    Train loss: 818690.24185
    Valid loss: 974625.87779
       Valid loss recons: 974595.86161
       Valid loss KL-D:   30.01625
Epoch 17 Iter 115: loss 816395.93750. : 100%|██████████| 115/115 [00:08<00:00, 14.05it/s]
Epoch 18 Iter 115: loss 829197.81250. : 100%|██████████| 115/115 [00:10<00:00, 10.48it/s]
Epoch 19 Iter 115: loss 845779.81250. : 100%|██████████| 115/115 [00:08<00:00, 13.18it/s]
Epoch 20 Iter 115: loss 803577.12500. : 100%|██████████| 115/115 [00:09<00:00, 11.67it/s]
Epoch 21 Iter 115: loss 812549.56250. : 100%|██████████| 115/115 [00:09<00:00, 12.45it/s]
    Train loss: 816212.05054
    Valid loss: 972471.4774
       Valid loss recons: 972447.58445
       Valid loss KL-D:   23.89596
Epoch 22 Iter 115: loss 832423.25000. : 100%|██████████| 115/115 [00:07<00:00, 14.80it/s]
Epoch 23 Iter 115: loss 815158.68750. : 100%|██████████| 115/115 [00:07<00:00, 15.82it/s]
Epoch 24 Iter 115: loss 811104.56250. : 100%|██████████| 115/115 [00:08<00:00, 14.00it/s]
Epoch 25 Iter 115: loss 817898.87500. : 100%|██████████| 115/115 [00:11<00:00,  9.70it/s]
Epoch 26 Iter 115: loss 795722.43750. : 100%|██████████| 115/115 [00:07<00:00, 14.81it/s]
    Train loss: 814434.66033
    Valid loss: 971056.20703
       Valid loss recons: 971035.87798
       Valid loss KL-D:   20.32735
Epoch 27 Iter 115: loss 841175.68750. : 100%|██████████| 115/115 [00:09<00:00, 12.62it/s]
Epoch 28 Iter 115: loss 835500.31250. : 100%|██████████| 115/115 [00:10<00:00, 11.10it/s]
Epoch 29 Iter 115: loss 794173.31250. : 100%|██████████| 115/115 [00:08<00:00, 13.18it/s]
Epoch 30 Iter 115: loss 804337.00000. : 100%|██████████| 115/115 [00:09<00:00, 12.15it/s]
Epoch 31 Iter 115: loss 820416.62500. : 100%|██████████| 115/115 [00:08<00:00, 14.31it/s]
    Train loss: 813109.89293
    Valid loss: 970132.38291
       Valid loss recons: 970114.29371
       Valid loss KL-D:   18.09056
Epoch 32 Iter 115: loss 865708.06250. : 100%|██████████| 115/115 [00:08<00:00, 13.84it/s]
Epoch 33 Iter 115: loss 828241.25000. : 100%|██████████| 115/115 [00:13<00:00,  8.64it/s]
Epoch 34 Iter 115: loss 813692.81250. : 100%|██████████| 115/115 [00:08<00:00, 13.50it/s]
Epoch 35 Iter 115: loss 792860.06250. : 100%|██████████| 115/115 [00:08<00:00, 13.77it/s]
Epoch 36 Iter 115: loss 777841.87500. : 100%|██████████| 115/115 [00:08<00:00, 13.40it/s]
    Train loss: 811586.27065
    Valid loss: 969411.65281
       Valid loss recons: 969395.84933
       Valid loss KL-D:   15.8045
Epoch 37 Iter 115: loss 832186.68750. : 100%|██████████| 115/115 [00:08<00:00, 13.13it/s]
Epoch 38 Iter 115: loss 755650.25000. : 100%|██████████| 115/115 [00:15<00:00,  7.66it/s]
Epoch 39 Iter 115: loss 872910.37500. : 100%|██████████| 115/115 [00:06<00:00, 16.65it/s]
Epoch 40 Iter 115: loss 820688.18750. : 100%|██████████| 115/115 [00:07<00:00, 14.51it/s]
Epoch 41 Iter 115: loss 839223.18750. : 100%|██████████| 115/115 [00:13<00:00,  8.84it/s]
    Train loss: 810818.1712
    Valid loss: 968845.45601
       Valid loss recons: 968830.35565
       Valid loss KL-D:   15.10062
Epoch 42 Iter 115: loss 849911.31250. : 100%|██████████| 115/115 [00:08<00:00, 13.97it/s]
Epoch 43 Iter 115: loss 826577.06250. : 100%|██████████| 115/115 [00:08<00:00, 13.02it/s]
Epoch 44 Iter 115: loss 807671.93750. : 100%|██████████| 115/115 [00:08<00:00, 14.36it/s]
Epoch 45 Iter 115: loss 806533.37500. : 100%|██████████| 115/115 [00:08<00:00, 13.76it/s]
    Train loss: 809494.62989
    Valid loss: 968470.94671
       Valid loss recons: 968456.9202
       Valid loss KL-D:   14.02405
Training completed

In [ ]:
imgs, _ = next(iter(valid_loader)) 

vae.eval()
with torch.no_grad():
    recons, _ = vae(imgs.to(device))
    
fig, ax = plt.subplots(2, 11)
fig.set_size_inches(18, 5)
for i in range(11):
    img = imgs[i+15]
    img = normalize_img(img)
    recon = recons[i+15]
    recon = recon.cpu().permute(1,2,0)
    ax[0, i].imshow(img)
    ax[0, i].axis("off")
    ax[1, i].imshow(recon)
    ax[1, i].axis("off")

ax[0, 5].set_title("Original Image")
ax[1, 5].set_title("Reconstruction")
plt.tight_layout()
plt.show()
In [ ]:
def plot_metrics(loss_iters, train_loss, val_loss, val_loss_recons, val_loss_kld):
    filtered_loss_iters = np.array(loss_iters)
    med = np.median(filtered_loss_iters)
    filtered_loss_iters[loss_iters < med / 2] = med

    plt.style.use('seaborn')
    fig, ax = plt.subplots(1,4)
    fig.set_size_inches(30,5)

    smooth_loss = smooth(filtered_loss_iters, 31)
    ax[0].plot(filtered_loss_iters, c="blue", label="Loss", linewidth=3, alpha=0.5)
    ax[0].plot(smooth_loss, c="red", label="Smoothed Loss", linewidth=3, alpha=1)
    ax[0].legend(loc="best")
    ax[0].set_xlabel("Iteration")
    ax[0].set_ylabel("CE Loss")
    ax[0].set_yscale("log")
    ax[0].set_title("Training Progress")

    smooth_loss = smooth(filtered_loss_iters, 31)
    START = 500
    N_ITERS = len(filtered_loss_iters)
    ax[1].plot(np.arange(START, N_ITERS), filtered_loss_iters[START:], c="blue", label="Loss", linewidth=3, alpha=0.5)
    ax[1].plot(np.arange(START, N_ITERS), smooth_loss[START:], c="red", label="Smoothed Loss", linewidth=3, alpha=1)
    ax[1].legend(loc="best")
    ax[1].set_xlabel("Iteration")
    ax[1].set_ylabel("Loss")
    ax[1].set_yscale("log")
    ax[1].set_title(f"Training Progress from Iter {START}")

    epochs = np.arange(len(train_loss)) + 1
    ax[2].plot(epochs[1:], train_loss[1:], c="red", label="Train Loss", linewidth=3)
    ax[2].plot(epochs[1:], val_loss[1:], c="blue", label="Valid Loss", linewidth=3)
    ax[2].legend(loc="best")
    ax[2].set_xlabel("Epochs")
    ax[2].set_ylabel("Loss")
    ax[2].set_title("Loss Curves")

    epochs = np.arange(len(val_loss)) + 1
    ax[3].plot(epochs[1:], val_loss[1:], c="blue", label="Valid Loss Total", linewidth=3)
    ax[3].plot(epochs[1:], val_loss_recons[1:], c="green", label="Recons. Loss", linewidth=2)
    ax[3].plot(epochs[1:], val_loss_kld[1:], c="purple", label="KLD Loss", linewidth=2)
    ax[3].legend(loc="best")
    ax[3].set_xlabel("Epochs")
    ax[3].set_ylabel("Loss")
    ax[3].set_yscale("log")
    ax[3].set_title("Independent Loss Curves")

    plt.show()
In [ ]:
plot_metrics(los_iters_vae, train_los_vae, val_los_vae, val_los_recons_vae, val_los_kld_vae)
  • b) Implement a Conditional Variational Autoencoder, which generates images based on a given class. Show the capabilities of your model.

In [ ]:
class CVAE(nn.Module):
    def __init__(self, input_size, latent_size, class_size):
        super(CVAE, self).__init__()
        self.input_size = input_size
        self.class_size = class_size
        self.latent_size = latent_size
        self.units = 400
        self.encode1 = nn.Linear(input_size + self.class_size, self.units)
        self.encode2 = nn.Linear(self.units, self.units//2)
        self.encode3 = nn.Linear(self.units//2, latent_size)
        self.encode4 = nn.Linear(self.units//2, latent_size)
        self.decode1 = nn.Linear(latent_size + self.class_size, self.units//2)
        self.decode2 = nn.Linear(self.units//2, self.units)
        self.decode3 = nn.Linear(self.units, self.input_size)



    def encoding_model(self, x, c):
        theinput = torch.cat((x.float(), c.float()), 1)
        output = self.encode1(theinput)
        output = self.encode2(output)
        mu = self.encode3(output)
        logvar = self.encode4(output)
        return mu, logvar

    def decoding_model(self, z, c):
        z_input = torch.cat((z.float(), c.float()), 1)
        output = self.decode1(z_input)
        output = self.decode2(output)
        x_hat = self.decode3(output)
        return x_hat

    def forward(self, x, c):
        x = x.view(-1, 32*32*3)
        #c = c.view(-1, 10*3)
        mu, logvar = self.encoding_model(x, c)
        z = self.reparametrize(mu, logvar)
        x_hat = self.decoding_model(z, c)
        x_hat = x_hat.view(-1, 3, 32, 32)
        return x_hat, (z, mu, logvar)
   
    def reparametrize(self, mu, logvar):
        # std = logvar.mul(0.5).exp_()
        # epsilon = Variable(std.data.new(std.size()).normal_())
        # return epsilon.mul(std) + mu

        """ Reparametrization trick"""
        std = torch.exp(0.5*logvar)
        eps = torch.randn_like(std)  # random sampling happens here
        z = mu + std * eps
        return z


def cvae_loss(recon_x, x, mu, logvar, lambda_kld=1e-3):
    recon_loss = F.mse_loss(recon_x, x, reduction='sum')
    # BCE = nn.MSELoss(reduction='sum')(recon_x, x)
    KLD = -0.5 * torch.mean(1 + logvar - mu.pow(2) - logvar.exp())

    return recon_loss + KLD, (recon_loss, KLD)
In [ ]:
latent_size = 12
cvae = CVAE(32*32*3, latent_size, 10).to(device)
cvae_optimizer = torch.optim.Adam(cvae.parameters(), lr=1e-4)
cvae_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(cvae_optimizer, mode='min', factor=0.5, patience=5, verbose=True)


train_loss_cvae, val_loss_cvae, los_iters_cvae, val_loss_recons_cvae, val_loss_kld_cvae = train_model(
    model=cvae, optimizer=cvae_optimizer, scheduler=cvae_scheduler, criterion=cvae_loss, train_loader=train_loader, valid_loader=valid_loader, num_epochs=45, savepath="imgs/cvae", save_frequency=5,
    isConditional=True, model_name="cvae"
)
Epoch 1 Iter 115: loss 352643.56250. : 100%|██████████| 115/115 [00:05<00:00, 21.97it/s]
    Train loss: 608324.36549
    Valid loss: 1611716.83389
       Valid loss recons: 1611716.83333
       Valid loss KL-D:   0.02069
Epoch 2 Iter 115: loss 267606.00000. : 100%|██████████| 115/115 [00:07<00:00, 16.03it/s]
Epoch 3 Iter 115: loss 221141.18750. : 100%|██████████| 115/115 [00:07<00:00, 16.05it/s]
Epoch 4 Iter 115: loss 202758.31250. : 100%|██████████| 115/115 [00:06<00:00, 16.55it/s]
Epoch 5 Iter 115: loss 215614.67188. : 100%|██████████| 115/115 [00:06<00:00, 16.75it/s]
Epoch 6 Iter 115: loss 200854.75000. : 100%|██████████| 115/115 [00:07<00:00, 14.70it/s]
    Train loss: 200254.07269
    Valid loss: 208656.89414
       Valid loss recons: 208591.78006
       Valid loss KL-D:   65.11498
Epoch 7 Iter 115: loss 199607.14062. : 100%|██████████| 115/115 [00:09<00:00, 12.15it/s]
Epoch 8 Iter 115: loss 203851.34375. : 100%|██████████| 115/115 [00:07<00:00, 15.40it/s]
Epoch 9 Iter 115: loss 196581.57812. : 100%|██████████| 115/115 [00:08<00:00, 14.09it/s]
Epoch 10 Iter 115: loss 206220.96875. : 100%|██████████| 115/115 [00:11<00:00,  9.99it/s]
Epoch 11 Iter 115: loss 191991.56250. : 100%|██████████| 115/115 [00:08<00:00, 13.30it/s]
    Train loss: 198149.61223
    Valid loss: 204916.92694
       Valid loss recons: 204822.01795
       Valid loss KL-D:   94.90902
Epoch 12 Iter 115: loss 203066.54688. : 100%|██████████| 115/115 [00:07<00:00, 14.56it/s]
Epoch 13 Iter 115: loss 204003.01562. : 100%|██████████| 115/115 [00:08<00:00, 13.38it/s]
Epoch 14 Iter 115: loss 192735.09375. : 100%|██████████| 115/115 [00:06<00:00, 16.45it/s]
Epoch 15 Iter 115: loss 201117.01562. : 100%|██████████| 115/115 [00:11<00:00,  9.84it/s]
Epoch 16 Iter 115: loss 203098.20312. : 100%|██████████| 115/115 [00:07<00:00, 14.85it/s]
    Train loss: 197457.05217
    Valid loss: 204370.42653
       Valid loss recons: 204281.39811
       Valid loss KL-D:   89.02765
Epoch 17 Iter 115: loss 194368.82812. : 100%|██████████| 115/115 [00:08<00:00, 13.91it/s]
Epoch 18 Iter 115: loss 204396.45312. : 100%|██████████| 115/115 [00:12<00:00,  9.58it/s]
Epoch 19 Iter 115: loss 194369.39062. : 100%|██████████| 115/115 [00:06<00:00, 16.80it/s]
Epoch 20 Iter 115: loss 196466.40625. : 100%|██████████| 115/115 [00:08<00:00, 14.32it/s]
Epoch 21 Iter 115: loss 198856.18750. : 100%|██████████| 115/115 [00:11<00:00,  9.86it/s]
    Train loss: 197205.90707
    Valid loss: 203893.8604
       Valid loss recons: 203821.49082
       Valid loss KL-D:   72.36965
Epoch 22 Iter 115: loss 208106.29688. : 100%|██████████| 115/115 [00:07<00:00, 15.07it/s]
Epoch 23 Iter 115: loss 200167.92188. : 100%|██████████| 115/115 [00:08<00:00, 13.48it/s]
Epoch 24 Iter 115: loss 200608.01562. : 100%|██████████| 115/115 [00:11<00:00,  9.99it/s]
Epoch 25 Iter 115: loss 198296.46875. : 100%|██████████| 115/115 [00:08<00:00, 14.01it/s]
Epoch 26 Iter 115: loss 189027.68750. : 100%|██████████| 115/115 [00:10<00:00, 10.79it/s]
    Train loss: 196952.7519
    Valid loss: 203659.08598
       Valid loss recons: 203602.42406
       Valid loss KL-D:   56.66177
Epoch 27 Iter 115: loss 212070.79688. : 100%|██████████| 115/115 [00:07<00:00, 14.45it/s]
Epoch 28 Iter 115: loss 202171.03125. : 100%|██████████| 115/115 [00:08<00:00, 13.16it/s]
Epoch 29 Iter 115: loss 201852.67188. : 100%|██████████| 115/115 [00:14<00:00,  7.72it/s]
Epoch 30 Iter 115: loss 195654.65625. : 100%|██████████| 115/115 [00:07<00:00, 16.18it/s]
Epoch 31 Iter 115: loss 194785.89062. : 100%|██████████| 115/115 [00:07<00:00, 14.73it/s]
    Train loss: 196834.98193
    Valid loss: 203396.84375
       Valid loss recons: 203351.85677
       Valid loss KL-D:   44.98704
Epoch 32 Iter 115: loss 188516.14062. : 100%|██████████| 115/115 [00:10<00:00, 11.33it/s]
Epoch 33 Iter 115: loss 202829.28125. : 100%|██████████| 115/115 [00:08<00:00, 13.76it/s]
Epoch 34 Iter 115: loss 185823.84375. : 100%|██████████| 115/115 [00:07<00:00, 16.14it/s]
Epoch 35 Iter 115: loss 187175.68750. : 100%|██████████| 115/115 [00:06<00:00, 17.20it/s]
Epoch 36 Iter 115: loss 191699.87500. : 100%|██████████| 115/115 [00:07<00:00, 15.10it/s]
Epoch 00036: reducing learning rate of group 0 to 5.0000e-05.
    Train loss: 196769.09226
    Valid loss: 203680.57071
       Valid loss recons: 203643.71003
       Valid loss KL-D:   36.86133
Epoch 37 Iter 115: loss 190602.81250. : 100%|██████████| 115/115 [00:08<00:00, 13.55it/s]
Epoch 38 Iter 115: loss 207870.75000. : 100%|██████████| 115/115 [00:06<00:00, 16.99it/s]
Epoch 39 Iter 115: loss 197547.12500. : 100%|██████████| 115/115 [00:08<00:00, 14.26it/s]
Epoch 40 Iter 115: loss 190238.45312. : 100%|██████████| 115/115 [00:11<00:00, 10.11it/s]
Epoch 41 Iter 115: loss 183063.29688. : 100%|██████████| 115/115 [00:08<00:00, 13.19it/s]
    Train loss: 196386.12011
    Valid loss: 203068.85745
       Valid loss recons: 203034.31115
       Valid loss KL-D:   34.54609
Epoch 42 Iter 115: loss 195143.14062. : 100%|██████████| 115/115 [00:07<00:00, 15.39it/s]
Epoch 43 Iter 115: loss 192820.20312. : 100%|██████████| 115/115 [00:11<00:00, 10.43it/s]
Epoch 44 Iter 115: loss 182292.62500. : 100%|██████████| 115/115 [00:07<00:00, 14.60it/s]
Epoch 45 Iter 115: loss 187787.31250. : 100%|██████████| 115/115 [00:08<00:00, 14.11it/s]
    Train loss: 196447.7447
    Valid loss: 203087.86998
       Valid loss recons: 203054.98765
       Valid loss KL-D:   32.88296
Training completed

In [ ]:
imgs, labels = next(iter(valid_loader)) 
labels = F.one_hot(labels, 10)

cvae.eval()
with torch.no_grad():
    recons, _ = cvae(imgs.to(device), labels.to(device))
    
fig, ax = plt.subplots(2, 11)
fig.set_size_inches(18, 5)
for i in range(11):
    img = imgs[i+15]
    img = normalize_img(img)
    recon = recons[i+15]
    recon = normalize_img(recon.cpu())
    # recon = recon.cpu().permute(1,2,0)
    ax[0, i].imshow(img)
    ax[0, i].axis("off")
    ax[1, i].imshow(recon)
    ax[1, i].axis("off")

ax[0, 5].set_title("Original Image")
ax[1, 5].set_title("Reconstruction")
plt.tight_layout()
plt.show()
In [ ]:
plot_metrics(los_iters_cvae, train_loss_cvae, val_loss_cvae, val_loss_recons_cvae, val_loss_kld_cvae)
In [ ]:
@torch.no_grad()
def generate_samples(model, num_samples=10, class_size=10, c_multiplayer=1 ):
    model.eval()
    samples = []
    for i in range(num_samples):
        z = torch.randn(1, 12).to(device)
        z *= 2
        #z = - F.one_hot(torch.tensor([i]), 20).to(device)
        c = F.one_hot(torch.tensor([i]), class_size).to(device)
        sample = model.decoding_model(z, c*c_multiplayer)
        sample = sample.view(3, 32, 32)
        samples.append(sample)
    return samples

def generateNshow(model, num_samples=10, class_size=10, c_multiplayer=1):

    samples = generate_samples(model, num_samples=num_samples, class_size=class_size, c_multiplayer=c_multiplayer)
    fig, ax = plt.subplots(1, 10)
    fig.set_size_inches(18, 5)
    for i in range(10):
        sample = samples[i]
    
        sample = normalize_img(sample.cpu())
        ax[i].imshow(sample)
        ax[i].axis("off")
        ax[i].set_title(f"{i}")
    plt.show()


generateNshow(cvae, num_samples=10, class_size=10, c_multiplayer=1)

We can see very "weak" numbers which correspond to the labels.
Let's try to multiply the one_hot encoded label vector

In [ ]:
generateNshow(cvae, num_samples=10, class_size=10, c_multiplayer=15)

Now we can see the numbers better
As the latent space is big, the random values in function generateNshow(..) have more impact on the final result. And some images of numbers do not fully correspond to the labels
I have tried it with latent space size = 2. In this case result image depends on the given label much stronger. But the results of reconstruction are weaker.

  • c) Investigate latent space and visualize some interpolations
In [ ]:
@torch.no_grad()
def plot_reconstructed(model, xrange=(-3, 3), yrange=(-2, 2), N=12, conditional=False, cond_vis_mode: ("diag","hori","vert")="diag"):
    """
    Sampling equispaced points from the latent space givent the xange and yrange, 
    decoding latents and visualizing distribution of the space
    """
    SIZE = 32
    grid = np.empty((N*SIZE, N*SIZE, 3))

    
    
    for i, y in enumerate(np.linspace(*yrange, N)):
        for j, x in enumerate(np.linspace(*xrange, N)):
            dummy_values = torch.randn(10)

            z = torch.Tensor([[x*2, y*2, *dummy_values]]).to(device)
            if conditional:
                if cond_vis_mode == "diag":
                    selected_label = int((i+j) / (2*(N-0.9)) * 10) 
                elif cond_vis_mode == "hori":
                    selected_label = int(i / (N-0.9) * 10)
                elif cond_vis_mode == "vert":
                    selected_label = int(j / (N-0.9) * 10)
                c = F.one_hot(torch.tensor([selected_label]), 10)
                # multiply one hot vector by 20 to make it more visible
                c *= 10
                c = c.to(device)
                x_hat = model.decoding_model(z, c).cpu()
            else:
                x_hat = model.decode(z).cpu()
            x_hat = x_hat.view(3, 32,32)
            
            grid[(N-1-i)*SIZE:(N-i)*SIZE, j*SIZE:(j+1)*SIZE] = normalize_img(x_hat)
           
    plt.figure(figsize=(12,20))
    plt.imshow(grid, extent=[*yrange, *xrange], cmap="gray")
    plt.axis("off")

Interpolations of Conv VAE:

As the latent space has 12 dimensions, it's difficult to observe all possible interpolations. But here are few samples of interpolations.

In [ ]:
plot_reconstructed(vae, xrange=(-2, 2), yrange=(-2, 2), N=20, conditional=False)
In [ ]:
plot_reconstructed(vae, xrange=(-2, 2), yrange=(-2, 2), N=20, conditional=False)
In [ ]:
plot_reconstructed(vae, xrange=(-2, 2), yrange=(-2, 2), N=20, conditional=False)

Interpolations of CVAE:

In [ ]:
plot_reconstructed(cvae, xrange=(-2, 2), yrange=(-2, 2), N=20, conditional=True)
In [ ]:
plot_reconstructed(cvae, xrange=(-2, 2), yrange=(-2, 2), N=20, conditional=True, cond_vis_mode="hori")
In [ ]:
plot_reconstructed(cvae, xrange=(-2, 2), yrange=(-2, 2), N=20, conditional=True, cond_vis_mode="vert")

We can observe different variations of average number-representations. Labels are changing in the diagonal/horizontal/vertical directions

Here is an example of interpolations in cvae with latent space size:

For this task we should use the Inception V3 model, which is pretrained on Imagenet dataset.
The use of activations from the Inception V3 model to summarize each image gives the score its name of “Frechet Inception Distance.”

In [ ]:
import torchvision.models
import glob
from torch.utils import data
from PIL import Image

# get inception v3 model 
inseption_model = models.inception_v3(weights=torchvision.models.Inception_V3_Weights.IMAGENET1K_V1)
inseption_model = inseption_model.to(device)
In [ ]:
def get_moments(model, samples):
    model.eval()

    with torch.no_grad():
        X_sum = torch.zeros((1000, 1)).to(device)
        XXT_sum = torch.zeros((1000, 1000)).to(device)
        count = 0

        for inp in tqdm(samples):
          # [B, F]
          pred = model(inp.to(device))
          # [B, F] -> [1, F] -> [F, 1]
          X_sum += pred.sum(dim=0, keepdim=True).T
          # [B, 1, F] x [B, F, 1] -> [B, F, F] -> [F, F]
          XXT_sum += (pred[:, None] * pred[..., None]).sum(0)
          count += len(inp)

        X_mean = X_sum / count
        X_cov = XXT_sum / count - X_mean @ X_mean.T

    return X_mean, X_cov

def frechet_inception_distance(m_w, C_w, m, C, debug=False):
    eigenvals = torch.linalg.eigvals(C @ C_w)
    trace_sqrt_CCw = eigenvals.real.clamp(min=0).sqrt().sum()
    if debug:
        print('Largest imaginary part magnitude:', eigenvals[eigenvals.imag > 0].abs().max().item())
        print('Most negative:', eigenvals[eigenvals.real < 0].real.min().item())
        print()
    fid = ((m - m_w)**2).sum() + C.trace() + C_w.trace() - 2 * trace_sqrt_CCw
    return fid

class FIDDataset(data.Dataset):
    def __init__(self, filelist=None, path=None):
        if filelist is not None:
          self.filelist = filelist
        else:
          self.filelist = glob.glob(f'{path}')
        self.transform = torchvision.models.Inception_V3_Weights.IMAGENET1K_V1.transforms()


    def __len__(self):
        return len(self.filelist)

    def __getitem__(self, idx):
        img = Image.open(self.filelist[idx])
        img = self.transform(img)
        return img
In [ ]:
cvae_recons_file_list = np.random.RandomState(42).permutation(glob.glob(f'./imgs/cvae/*png'))[:1000]
vae_recons_file_list = np.random.RandomState(42).permutation(glob.glob(f'./imgs/vae/*png'))

cvae_recons_dataset = FIDDataset(filelist=cvae_recons_file_list)
vae_recons_dataset = FIDDataset(filelist=vae_recons_file_list)

cvae_recons_loader = data.DataLoader(cvae_recons_dataset, batch_size=100, prefetch_factor=2, shuffle=False)
vae_recons_loader = data.DataLoader(vae_recons_dataset, batch_size=100, shuffle=False)
In [ ]:
m_cvae, C_cvae = get_moments(inseption_model, cvae_recons_loader)
m_vae, C_vae = get_moments(inseption_model, vae_recons_loader)

fid_cvae = frechet_inception_distance(m_vae, C_vae, m_cvae, C_cvae, debug=True)
fid_vae = frechet_inception_distance(m_cvae, C_cvae, m_vae, C_vae, debug=True)

print(f"FID for CVAE: {fid_cvae:{3}.{8}}")
print(f"FID for VAE: {fid_vae:{3}.{8}}")
100%|██████████| 1/1 [00:00<00:00, 19.66it/s]
100%|██████████| 1/1 [00:00<00:00, 19.56it/s]
Largest imaginary part magnitude: 0.0003027394413948059
Most negative: -0.00022941167117096484

Largest imaginary part magnitude: 0.00032907715649344027
Most negative: -0.000238061387790367

FID for CVAE: 171.54089
FID for VAE: 171.53049

Both networks perform very similar. According to the FID score Convolutional Variational Autoencoder performs a bit better than Conditional Variational Autoencoder (smaller values are better)

Performance of both networks depends very much on dimensions/size of latent space.
In this assignment I have used both networks with latent dimentions = 12.

Here is an example of Convolutional VEA reconstruction:

Here is an example of the Conditional VEA with latent space dim = 2:

It's difficult to store a lot of information only in 2D, that's why a lot of reconstructions seem not very clear.

  • Extra point:
    • Train a beta-VAE using ResNet-based encoders and decoders.
    • Encoder is a ResNet-18
    • Decoder is the mirrored version of the encoder
    • Compare this model with the previous ones,
In [ ]: